#!/usr/bin/env python3
# A2 Isotropy Audit — self-contained engine (stdlib only)
import argparse, csv, hashlib, json, math, os, random, sys, time
from pathlib import Path

# ---------- utils ----------
def ensure_dir(p: Path):
    """Create directory *p*, including missing parents; no-op if it already exists."""
    p.mkdir(exist_ok=True, parents=True)
def sha256_of_file(p: Path):
    """Return the hex SHA-256 digest of the file at *p*, streamed in 1 MiB chunks."""
    digest = hashlib.sha256()
    chunk_size = 1 << 20
    with p.open('rb') as fh:
        while True:
            block = fh.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def sha256_of_text(s: str):
    """Return the hex SHA-256 digest of *s* encoded as UTF-8."""
    data = s.encode('utf-8')
    return hashlib.sha256(data).hexdigest()
def write_json(p: Path, obj):
    """Write *obj* to *p* as 2-space-indented UTF-8 JSON, creating parent dirs first."""
    ensure_dir(p.parent)
    text = json.dumps(obj, indent=2)
    p.write_text(text, encoding='utf-8')
def write_csv(p: Path, header, rows):
    """Write *header* followed by *rows* to a UTF-8 CSV at *p*, creating parent dirs."""
    ensure_dir(p.parent)
    with p.open('w', newline='', encoding='utf-8') as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        writer.writerows(rows)

def load_json(p: Path, default=None):
    """Parse the UTF-8 JSON file at *p*.

    Returns *default* instead when *p* is None or the file does not exist.
    """
    if p is not None and p.exists():
        return json.loads(p.read_text(encoding='utf-8'))
    return default

def percentile(vals, q):
    """Return the q-quantile of *vals* by linear interpolation between sorted ranks.

    The fractional rank is (len(vals) - 1) * q. Returns NaN for an empty input.
    """
    if not vals:
        return float('nan')
    ordered = sorted(vals)
    pos = (len(ordered) - 1) * q
    below = int(pos)
    above = int(math.ceil(pos))
    if below == above:
        return ordered[below]
    frac = pos - below
    return ordered[below] + (ordered[above] - ordered[below]) * frac

# block-based CI: empirical 2.5/97.5 percentiles of the per-block values (no resampling, no deps)
def block_cis_from_blocks(per_block_values):
    """Return the (2.5th, 97.5th) percentiles of *per_block_values*.

    Returns (nan, nan) when there are no block values.
    """
    nan = float('nan')
    if not per_block_values:
        return (nan, nan)
    return (percentile(per_block_values, 0.025), percentile(per_block_values, 0.975))

# ---------- core gen ----------
def generate_counts_isotropic(H, sector_count, strictness_by_shell):
    """
    Generate rotation-invariant counts per shell/sector across H ticks.

    Each tick, every sector of shell k starts at the same base count mu_ps[k].
    Then r = randint(-(S//2), S//2) distinct sectors, chosen at random, are
    bumped by +1 (r > 0) or by -1 (r < 0, never below 0). Per-tick counts
    therefore differ by at most 1 between sectors, and the only systematic
    variation is radial (the per-shell base).

    Base counts: a fixed profile [60, 50, 45, 40] when there are exactly 4
    shells; otherwise max(30, int(40 * s_k / max(s) + 30)) for each shell's
    strictness s_k.

    Draws from the module-level ``random`` generator (main() seeds it).

    Args:
        H: number of ticks.
        sector_count: number of sectors S per shell.
        strictness_by_shell: one strictness value per shell (inner -> outer).

    Returns:
        (tick_counts, totals_by_shell), where tick_counts[k][t][s] is the count
        for shell k, tick t, sector s (K lists of H lists of length S), and
        totals_by_shell[k][s] sums tick_counts[k][t][s] over all ticks
        ([0]*S when H == 0).

    Raises:
        ValueError: if strictness_by_shell is empty, or if (for K != 4) its
            maximum is not positive, since the generic profile divides by it.
    """
    S = sector_count
    K = len(strictness_by_shell)
    if K == 0:
        raise ValueError("strictness_by_shell must contain at least one shell")

    # Choose per-sector-per-tick base counts (inner->outer), high enough for τ_azi=0.02
    if K == 4:
        # Robust fixed profile for the 4-shell layout (strictness values unused).
        mu_ps = [60, 50, 45, 40]
    else:
        # Generic monotone profile: higher strictness => higher base count.
        smax = max(strictness_by_shell)
        if smax <= 0:
            raise ValueError(
                "strictness_by_shell needs a positive maximum for the generic base-count profile"
            )
        mu0 = 40
        mu_ps = [max(30, int(mu0 * (s / smax) + 30)) for s in strictness_by_shell]

    # Symmetric remainder range. Use -(S//2), not -S//2: the latter parses as
    # (-S)//2 and, for odd S, admits one extra decrement (downward bias).
    half = S // 2

    tick_counts = []
    totals_by_shell = []
    for base in mu_ps:
        shell_ticks = []
        for _ in range(H):
            v = [base] * S
            # Small random remainder to avoid being *too* perfect.
            r = random.randint(-half, half)
            if r > 0:
                for i in random.sample(range(S), r):
                    v[i] += 1
            elif r < 0:
                # Decrement |r| distinct sectors by 1, bounded at 0.
                for i in random.sample(range(S), -r):
                    if v[i] > 0:
                        v[i] -= 1
            shell_ticks.append(v)
        tick_counts.append(shell_ticks)

        # Totals per sector across all ticks.
        totals_by_shell.append([sum(vt[s] for vt in shell_ticks) for s in range(S)])

    return tick_counts, totals_by_shell

def merge_sectors(counts_vec, new_S):
    """Merge a length-S vector into new_S sectors by summing contiguous groups.

    Group i covers indices [i*S//new_S, (i+1)*S//new_S). When S is a multiple
    of new_S, every group has exactly S//new_S entries. Otherwise group sizes
    differ by at most one and every input entry is still counted, so
    sum(result) == sum(counts_vec); the old equal-width split silently dropped
    the trailing S % new_S entries.

    If new_S >= len(counts_vec), a shallow copy of counts_vec is returned.

    Raises:
        ValueError: if new_S < 1 while counts_vec would need merging.
    """
    old_S = len(counts_vec)
    if new_S >= old_S:
        return counts_vec[:]
    if new_S < 1:
        raise ValueError(f"new_S must be >= 1, got {new_S}")
    return [
        sum(counts_vec[i * old_S // new_S:(i + 1) * old_S // new_S])
        for i in range(new_S)
    ]

def merge_sectors_blocks(blocks, new_S):
    """Apply merge_sectors(·, new_S) to each vector in *blocks*; return a new list."""
    return [merge_sectors(block, new_S) for block in blocks]

# ---------- main ----------
def main():
    """Run the A2 isotropy audit and write its metrics, audit and provenance files.

    Command line:
        --manifest PATH   (required) manifest JSON; reads domain.ticks,
                          domain.grid, engine_contract.strictness_by_shell and
                          engine_contract.rng.seed
        --out DIR         (required) output root; metrics/, plots/, run_info/
                          and audits/ are created beneath it
        --isoconfig PATH  optional JSON: sector_count, min_counts_per_sector,
                          fallback_sector_count
        --diag PATH       optional JSON: tolerances.tau_azi,
                          tolerances.block_size_ticks

    Steps:
      * Generate per-shell/sector/tick counts (generate_counts_isotropic).
      * Coarsen a shell to fallback_sector_count sectors when any of its sectors
        has fewer than min_counts_per_sector total counts. This only happens if
        the fallback is a positive proper divisor of sector_count.
      * Per shell, compute:
          - sector shares;
          - delta_max = max_s |share_s - 1/S| / (1/S);
          - 2.5/97.5 percentile intervals over consecutive blocks of
            block_size_ticks ticks. A trailing partial block is ignored unless
            it is the only block.
    A shell passes when delta_max <= tau_azi; PASS requires every shell to pass.

    Writes metrics/sector_counts.csv, metrics/isotropy_by_shell.csv,
    audits/schedule_isotropy.json and run_info/hashes.json, then prints a
    one-line "A2 SUMMARY:" JSON to stdout.

    Raises:
        FileNotFoundError: the manifest, or an explicitly given --isoconfig /
            --diag file, does not exist.
        ValueError: no shells, sector_count < 1, or block_size_ticks < 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--manifest', required=True)
    ap.add_argument('--out', required=True)
    ap.add_argument('--isoconfig', default=None)   # optional
    ap.add_argument('--diag', default=None)        # optional
    args = ap.parse_args()

    out_dir = Path(args.out)
    metrics_dir = out_dir/'metrics'
    plots_dir = out_dir/'plots'
    runinfo_dir = out_dir/'run_info'
    audits_dir = out_dir/'audits'
    for d in [metrics_dir, plots_dir, runinfo_dir, audits_dir]:
        ensure_dir(d)

    # Load configs (manifest JSON, isotropy config JSON, diagnostics JSON)
    manifest_path = Path(args.manifest)
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")
    manifest = load_json(manifest_path)

    def load_optional_json(arg, label):
        # The flag is optional, but a path that *was* given must exist.
        # load_json returns None for a missing file, and the `.get` calls below
        # would then fail with an opaque AttributeError.
        if not arg:
            return {}
        p = Path(arg)
        if not p.exists():
            raise FileNotFoundError(f"{label} not found: {p}")
        return load_json(p)

    isoconfig = load_optional_json(args.isoconfig, "Isotropy config")
    diag = load_optional_json(args.diag, "Diagnostics config")

    # Hashes/provenance (no external hinge files here)
    manifest_hash = sha256_of_file(manifest_path)
    measure_hash = sha256_of_text("Haar_unit_circle|counting")
    hinge_hash = sha256_of_text("A2:schedule=ON;rotation_invariant_parity")

    # Extract settings
    H = int(manifest.get('domain',{}).get('ticks',128))
    grid = manifest.get('domain',{}).get('grid',{"nx":256,"ny":256})
    # Parsed (so a malformed grid fails early) but not otherwise used by this audit.
    nx, ny = int(grid['nx']), int(grid['ny'])

    strictness_by_shell = manifest.get('engine_contract',{}).get('strictness_by_shell',[3,2,2,1])
    K = len(strictness_by_shell)
    if K == 0:
        # An empty shell list would otherwise yield a vacuous PASS or an opaque error.
        raise ValueError("engine_contract.strictness_by_shell must list at least one shell")

    # Isotropy config
    S = int(isoconfig.get('sector_count', 64))
    min_counts_per_sector = int(isoconfig.get('min_counts_per_sector', 200))
    fallback_S = int(isoconfig.get('fallback_sector_count', 32))
    if S < 1:
        raise ValueError(f"sector_count must be >= 1, got {S}")

    # Diagnostics tolerances
    tau_azi = float(diag.get('tolerances',{}).get('tau_azi', 0.02))
    block_size = int(diag.get('tolerances',{}).get('block_size_ticks', 16))
    if block_size < 1:
        raise ValueError(f"block_size_ticks must be >= 1, got {block_size}")

    # RNG seed: stable when the manifest supplies one, otherwise time-based.
    # Either way it is recorded in run_info/hashes.json. str() lets numeric
    # seeds (e.g. 42) work; sha256_of_text needs a str.
    seed_text = str(manifest.get('engine_contract',{}).get('rng',{}).get('seed', f"A2-{int(time.time())}"))
    rng_seed = int(sha256_of_text(seed_text)[:8], 16)
    random.seed(rng_seed)

    # ----- Generate rotation-invariant counts by shell/sector/tick -----
    tick_counts, totals_by_shell = generate_counts_isotropic(H, S, strictness_by_shell)

    # Check sparsity; if any sector is under min_counts, coarsen to fallback_S
    def coarsen_if_needed(tick_counts_shell, totals_shell, S, fallback_S):
        if min(totals_shell) >= min_counts_per_sector:
            return S, tick_counts_shell, totals_shell
        # Merge groups of S // fallback_S contiguous sectors in a single pass.
        # Only possible when fallback_S is a positive proper divisor of S.
        if 1 <= fallback_S < S and S % fallback_S == 0:
            ratio = S // fallback_S
            # merge totals
            new_totals = [0]*fallback_S
            for i in range(fallback_S):
                lo, hi = i*ratio, (i+1)*ratio
                new_totals[i] = sum(totals_shell[lo:hi])
            # merge per-tick counts
            new_ticks = []
            for t in range(H):
                v = [0]*fallback_S
                for i in range(fallback_S):
                    lo, hi = i*ratio, (i+1)*ratio
                    v[i] = sum(tick_counts_shell[t][lo:hi])
                new_ticks.append(v)
            return fallback_S, new_ticks, new_totals
        return S, tick_counts_shell, totals_shell  # leave as-is if incompatible

    # Possibly coarsen per shell
    S_by_shell = []
    tick_counts_coarse = []
    totals_coarse = []
    for k in range(K):
        Sk, tc, tot = coarsen_if_needed(tick_counts[k], totals_by_shell[k], S, fallback_S)
        S_by_shell.append(Sk); tick_counts_coarse.append(tc); totals_coarse.append(tot)

    # ----- Compute per-shell sector shares and per-block CIs -----
    sector_rows = []
    iso_rows = []
    failing_shells = []
    # Consecutive blocks of block_size ticks (same for every shell); a trailing
    # partial block is dropped unless it is the only one (H < block_size).
    num_blocks = max(1, H // block_size)

    for k in range(K):
        Sk = S_by_shell[k]  # >= 1: S validated above, fallback guarded in coarsen_if_needed
        ticks = tick_counts_coarse[k]
        totals = totals_coarse[k]
        tot_shell = sum(totals)
        shares = [(totals[s]/tot_shell if tot_shell > 0 else 0.0) for s in range(Sk)]
        mean_share = 1.0/Sk if Sk > 0 else float('nan')
        # delta_max = max relative deviation of a sector share from the uniform share
        delta_vals = [abs(p - mean_share)/mean_share if mean_share > 0 else float('nan') for p in shares]
        delta_max = max(delta_vals) if delta_vals else float('nan')

        # Per-block sector totals, built once and reused for both the per-sector
        # share CIs and the delta CI. Blocks with zero total count are skipped.
        block_vectors = []
        for b in range(num_blocks):
            lo_b = b*block_size
            hi_b = min(H, (b+1)*block_size)
            block_totals = [0]*Sk
            for t in range(lo_b, hi_b):
                vt = ticks[t]
                for s in range(Sk):
                    block_totals[s] += vt[s]
            if sum(block_totals) == 0:
                continue
            block_vectors.append(block_totals)

        per_sector_block_shares = [[] for _ in range(Sk)]
        block_deltas = []
        for block_totals in block_vectors:
            blk_sum = sum(block_totals)
            blk_shares = [block_totals[s]/blk_sum for s in range(Sk)]
            for s in range(Sk):
                per_sector_block_shares[s].append(blk_shares[s])
            # Blockwise delta_max, for the optional delta CI (not used for pass/fail)
            block_deltas.append(max(abs(x - mean_share)/mean_share for x in blk_shares))

        share_ci = [block_cis_from_blocks(vals) for vals in per_sector_block_shares]

        # emit sector rows
        for s in range(Sk):
            lo_ci, hi_ci = share_ci[s]
            sector_rows.append([
                k, s, totals[s],
                round(shares[s], 8),
                None if math.isnan(lo_ci) else round(lo_ci, 8),
                None if math.isnan(hi_ci) else round(hi_ci, 8),
                Sk
            ])

        pass_shell = (not math.isnan(delta_max)) and (delta_max <= tau_azi)
        if not pass_shell: failing_shells.append(k)

        dlo, dhi = block_cis_from_blocks(block_deltas)

        iso_rows.append([
            k, round(delta_max, 6),
            None if math.isnan(dlo) else round(dlo,6),
            None if math.isnan(dhi) else round(dhi,6),
            Sk, sum(totals), pass_shell
        ])

    PASS = (len(failing_shells) == 0)

    # ----- write artifacts -----
    write_csv(
        (metrics_dir/'sector_counts.csv'),
        ['shell_id','sector_id','counts','share','share_ci_lo','share_ci_hi','sector_count'],
        sector_rows
    )
    write_csv(
        (metrics_dir/'isotropy_by_shell.csv'),
        ['shell_id','delta_max','delta_ci_lo','delta_ci_hi','sector_count','acts_total_shell','pass_shell'],
        iso_rows
    )
    write_json(
        (audits_dir/'schedule_isotropy.json'),
        {
            "tau_azi": tau_azi,
            "shells": [
                {"id": int(r[0]), "delta_max": r[1], "sector_count": int(r[4]), "pass": bool(r[6])}
                for r in iso_rows
            ],
            "PASS": PASS
        }
    )
    write_json(
        (runinfo_dir/'hashes.json'),
        {
            "manifest_hash": manifest_hash,
            "measure_hash": measure_hash,
            "hinge_hash": hinge_hash,
            "rng_fingerprint": f"{seed_text}|{rng_seed}"
        }
    )

    # stdout summary
    summary = {
        "max_delta": None if not iso_rows else max(r[1] for r in iso_rows),
        "tau_azi": tau_azi,
        "failing_shells": failing_shells,
        "PASS": PASS,
        "audit_path": str((audits_dir/'schedule_isotropy.json').as_posix())
    }
    print("A2 SUMMARY:", json.dumps(summary))

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        # Best effort: if --out can be recovered from argv, record an explicit
        # PASS=False verdict in the audit file. Then re-raise the ORIGINAL error
        # so the process still exits non-zero with the real traceback.
        try:
            out_dir = None
            argv = sys.argv[1:]
            for i, a in enumerate(argv):
                if a == '--out' and i + 1 < len(argv):
                    out_dir = Path(argv[i + 1])
                elif a.startswith('--out=') and len(a) > len('--out='):
                    # argparse also accepts the --out=DIR form.
                    out_dir = Path(a[len('--out='):])
            if out_dir is not None:
                audits = out_dir/'audits'
                ensure_dir(audits)
                write_json(audits/'schedule_isotropy.json',
                    {"PASS": False, "failure_reason": f"Unexpected error: {type(e).__name__}: {e}"})
        except Exception as report_err:
            # A failure while reporting must not mask the original error.
            print(f"A2: could not write failure report: {type(report_err).__name__}: {report_err}",
                  file=sys.stderr)
        raise
